
import jax
import jax.numpy as jnp
import jax.random as random
import pickle
import data
import model
import options
import tabulate
from evaluation import invertibility,low_dimensionality,isometry
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from sklearn.manifold import Isomap
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import KFold


def reconstruct_geodesic(pred, adj, i, j):
    """Rebuild the shortest path from node ``i`` to node ``j``.

    Args:
        pred: Predecessor matrix; ``pred[i][k]`` is the node preceding ``k``
            on the shortest path starting at ``i``, or ``-9999`` when none
            is recorded.
        adj: Pairwise distance matrix indexed by node; entries between
            consecutive path nodes give the segment lengths.
        i: Source node index.
        j: Target node index.

    Returns:
        ``(path, distances)``. ``path`` is a ``jnp`` array of node indices
        running from ``i`` to ``j``. ``distances`` is a list holding each
        consecutive segment length divided by their total, so it sums to 1.
        Both are empty when ``pred[i][j]`` is ``-9999``.
    """
    row = pred[i]
    if row[j] == -9999:
        # Sentinel value: no predecessor recorded for j from source i.
        return jnp.array([]), []

    # Walk the predecessor chain backwards from j until we reach i.
    backwards = [j]
    step = row[j]
    while step != i:
        backwards.append(step)
        step = row[step]
    backwards.append(i)
    path = backwards[::-1]

    # Lengths of consecutive segments sit on the first super-diagonal of the
    # path-restricted distance matrix.
    segment_lengths = jnp.diagonal(adj[path][:, path], offset=1)
    distances = (segment_lengths / segment_lengths.sum()).tolist()
    return jnp.array(path), distances

def ablation_experiment():
    """Compare four ablation variants of the isometry CNF on the held-out split.

    Loads the pickled parameters of each variant ("None", "GM", "Stability",
    "Both"). Each variant is evaluated on the 20% of points not drawn into
    the training split, using `invertibility`, `low_dimensionality` and
    `isometry`. The results are written as a LaTeX table to
    ``plots/ablation_experiment.txt``.
    """
    isometry_args = options.get_isometry_args()

    x, _ = data.load_data(debug=True)
    isomap = Isomap(n_neighbors=isometry_args.n_neighbors).fit(x)

    key = random.PRNGKey(isometry_args.seed)
    # Keep the 4-way split so rs_key (and hence the train/test split) is
    # exactly the same key as before; the other keys are unused here.
    _, rs_key, _, _ = random.split(key, 4)
    train_idx = random.choice(rs_key, x.shape[0], (int(x.shape[0] * 0.8),), replace=False)
    test_idx = jnp.setdiff1d(jnp.arange(x.shape[0]), train_idx)

    # Table row label -> pickled parameters of that ablation run.
    runs = {
        "None": "models/ablation/run-20240728_122548-3wlvv2s4/files/parameters",
        "GM": "models/ablation/run-20240728_122551-5wcnovvk/files/parameters",
        "Stability": "models/ablation/run-20240728_122550-x938cx0r/files/parameters",
        "Both": "models/ablation/run-20240728_122553-bpefnss6/files/parameters",
    }

    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=isometry_args.hidden_nf, n_layers=isometry_args.n_layers, n_steps=isometry_args.n_steps, seed=isometry_args.seed)

    x_test = x[test_idx]
    # Isomap distances restricted to the test points, plus the global maximum
    # passed to `isometry`; computed once instead of once per variant.
    test_dist = isomap.dist_matrix_[test_idx][:, test_idx]
    max_dist = isomap.dist_matrix_.max()

    table = [["Model", "Invertibility", "Low Dimensionality", "Isometry"]]
    for name, path in runs.items():
        with open(path, 'rb') as fp:
            varphi_pars = pickle.load(fp)
        # Forward to latent space, then invert back to measure reconstruction.
        varphi_z, _, _, _ = varphi.apply(varphi_pars, x_test)
        varphi_x_approx, _ = varphi.apply(varphi_pars, varphi_z, method="inverse")
        table.append([
            name,
            invertibility(x_test, varphi_x_approx),
            low_dimensionality(varphi_z, isometry_args.n_dim),
            isometry(varphi_z, test_dist, max_dist),
        ])

    formatted_table = tabulate.tabulate(table, headers="firstrow", tablefmt="latex_raw", floatfmt=(None, ".3e", ".3e", ".3e"))

    with open("plots/ablation_experiment.txt", "w") as f:
        f.write(formatted_table)

def _geodesic_times(dists):
    """Turn normalised segment lengths into cumulative interpolation times.

    Args:
        dists: Per-segment lengths of a path as returned by
            ``reconstruct_geodesic`` (normalised to sum to 1).

    Returns:
        A ``jnp`` array ``[0, d0, d0 + d1, ...]`` with one entry per path node.
    """
    times = [0]
    for d in dists:
        times.append(times[-1] + d)
    return jnp.array(times)


def _rmse(pred, target):
    """Root-mean-square error over all elements of ``pred - target``."""
    return jnp.sqrt(jnp.mean((pred - target) ** 2))


def interpolation_experiment():
    """Compare interpolation schemes against Isomap geodesics on test pairs.

    Up to 100 random pairs of held-out points are evaluated. For each pair,
    the Isomap shortest path is reconstructed and compared, by RMSE, with:
      * straight-line interpolation in data space,
      * straight lines in the CNF latent space mapped back through the inverse,
        both in full and with the first ``D - n_dim`` latent coordinates zeroed,
      * straight lines in the latent space of two VAEs (the second labelled
        beta=10.0), decoded back to data space.
    The mean/std RMSE table is written to ``plots/interpolation_experiment.txt``.
    One example geodesic is plotted to ``plots/interpolation_geodesic_example.png``.
    """
    isometry_args = options.get_isometry_args()
    vae_args = options.get_VAE_args()

    x, _ = data.load_data(debug=True)
    isomap = Isomap(n_neighbors=isometry_args.n_neighbors).fit(x)

    key = random.PRNGKey(isometry_args.seed)
    # Keep the 4-way split so rs_key / pairs_key are the same keys as before.
    _, rs_key, pairs_key, _ = random.split(key, 4)
    train_idx = random.choice(rs_key, x.shape[0], (int(x.shape[0] * 0.8),), replace=False)
    test_idx = jnp.setdiff1d(jnp.arange(x.shape[0]), train_idx)

    with open("models/interpolation/run-20240728_130602-cxjuyx7v/files/parameters", 'rb') as fp:
        varphi_pars = pickle.load(fp)

    with open("models/interpolation/run-20240729_143000-mttkscbg/files/parameters", 'rb') as fp:
        vae_pars = pickle.load(fp)

    with open("models/interpolation/run-20240729_143000-u2fzed02/files/parameters", 'rb') as fp:
        beta_vae_pars = pickle.load(fp)

    print("Number of parameters isometry: ", sum([p.size for p in jax.tree.leaves(varphi_pars)]))
    print("Number of parameters VAE: ", sum([p.size for p in jax.tree.leaves(vae_pars)]))
    print("Number of parameters VAE ($\\beta=10.0$): ", sum([p.size for p in jax.tree.leaves(beta_vae_pars)]))

    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=isometry_args.hidden_nf, n_layers=isometry_args.n_layers, n_steps=isometry_args.n_steps, seed=isometry_args.seed)
    vae = model.VAE(x_dim=x.shape[1], n_encoder_layers=vae_args.n_layers, n_decoder_layers=vae_args.n_layers, hidden_nf=vae_args.hidden_nf, latent_nf=vae_args.n_dim)
    varphi_z, _, _, _ = varphi.apply(varphi_pars, x)
    vae_z, _ = vae.apply(vae_pars, x, method="encode")
    beta_vae_z, _ = vae.apply(beta_vae_pars, x, method="encode")

    # NOTE: T's length equals the path length, so these jitted functions
    # retrace for every distinct path length encountered.
    @jax.jit
    def geodesic_isometry(varphi_z, varphi_pars, start_geo, end_geo, T):
        varphi_lat_geo = varphi_z[start_geo] + (varphi_z[end_geo] - varphi_z[start_geo]) * T[:, None]
        # Projection: zero the first D - n_dim latent coordinates, keep the last n_dim.
        varphi_proj_lat_geo = jnp.concatenate([jnp.zeros((T.shape[0], x.shape[1] - isometry_args.n_dim)), varphi_lat_geo[:, x.shape[1] - isometry_args.n_dim:]], axis=1)
        varphi_geo, _ = varphi.apply(varphi_pars, varphi_lat_geo, method="inverse")
        varphi_proj_geo, _ = varphi.apply(varphi_pars, varphi_proj_lat_geo, method="inverse")
        return varphi_geo, varphi_proj_geo, varphi_lat_geo, varphi_proj_lat_geo

    @jax.jit
    def geodesic_VAE(vae_z, vae_pars, start_geo, end_geo, T):
        vae_lat_geo = vae_z[start_geo] + (vae_z[end_geo] - vae_z[start_geo]) * T[:, None]
        vae_geo, _ = vae.apply(vae_pars, vae_lat_geo, method="decode")
        return vae_geo

    n_random_geodesics = 100

    # All unordered test pairs (i, j) with j < i, in the same row-major order
    # a nested `for i / for j in range(i)` loop would produce.
    rows, cols = jnp.tril_indices(test_idx.shape[0], k=-1)
    all_pairs = jnp.stack([rows, cols], axis=1)
    shuffled_indices = jax.random.permutation(pairs_key, all_pairs.shape[0])
    random_geo_pairs = all_pairs[shuffled_indices[:n_random_geodesics]]

    # Only evaluated pairs contribute; with fewer than 100 pairs no
    # placeholder zeros leak into the statistics.
    errors = {"li": [], "varphi": [], "varphi_proj": [], "vae": [], "beta_vae": []}

    for test_geo_start, test_geo_end in tqdm(random_geo_pairs, desc="Processing pairs"):
        start_geo = test_idx[test_geo_start]
        end_geo = test_idx[test_geo_end]
        path, dists = reconstruct_geodesic(isomap.predecessors, isomap.dist_matrix_, start_geo, end_geo)
        if len(path) == 0:
            # reconstruct_geodesic found no predecessor chain; there is no
            # reference geodesic to compare against, so skip this pair.
            continue
        T = _geodesic_times(dists)
        target = x[path]

        li_geo = x[start_geo] + (x[end_geo] - x[start_geo]) * T[:, None]
        varphi_geo, varphi_proj_geo, _, _ = geodesic_isometry(varphi_z, varphi_pars, start_geo, end_geo, T)
        vae_geo = geodesic_VAE(vae_z, vae_pars, start_geo, end_geo, T)
        beta_vae_geo = geodesic_VAE(beta_vae_z, beta_vae_pars, start_geo, end_geo, T)

        errors["li"].append(_rmse(li_geo, target))
        errors["varphi"].append(_rmse(varphi_geo, target))
        errors["varphi_proj"].append(_rmse(varphi_proj_geo, target))
        errors["vae"].append(_rmse(vae_geo, target))
        errors["beta_vae"].append(_rmse(beta_vae_geo, target))

    stats = {name: jnp.array(values) for name, values in errors.items()}

    # Raw strings: the LaTeX labels contain backslashes (\AA, \cdot, \mathcal)
    # that are invalid escape sequences in ordinary string literals.
    table = [["Model", r"Mean RMSE ($\AA$)", r"Std RMSE ($\AA$)"],
             ["Linear Interpolation", jnp.mean(stats["li"]), jnp.std(stats["li"])],
             [r"$(\cdot,\cdot)^{\mathcal{M}}$", jnp.mean(stats["varphi"]), jnp.std(stats["varphi"])],
             [r"$(\cdot,\cdot)^{\mathcal{M}_{d prime}}$", jnp.mean(stats["varphi_proj"]), jnp.std(stats["varphi_proj"])],
             ["VAE", jnp.mean(stats["vae"]), jnp.std(stats["vae"])],
             [r"VAE ($\beta=10.0$)", jnp.mean(stats["beta_vae"]), jnp.std(stats["beta_vae"])]]

    formatted_table = tabulate.tabulate(table, headers="firstrow", tablefmt="latex_raw", floatfmt=(None, ".3e", ".3e"))

    with open("plots/interpolation_experiment.txt", "w") as f:
        f.write(formatted_table)

    # Example figure for one fixed test pair.
    start_geo = test_idx[3]
    end_geo = test_idx[98]
    path, dists = reconstruct_geodesic(isomap.predecessors, isomap.dist_matrix_, start_geo, end_geo)
    T = _geodesic_times(dists)
    varphi_geo, _, _, _ = geodesic_isometry(varphi_z, varphi_pars, start_geo, end_geo, T)

    # Reference arc drawn in the figure, centred on its mean.
    n = 100
    temp = jnp.linspace(-1, 1, n)
    manifold = jnp.stack([-jnp.sin(0.5 * jnp.pi * temp), jnp.cos(0.5 * jnp.pi * temp)], axis=1)
    manifold = manifold - jnp.mean(manifold, axis=0)

    plt.figure(2, figsize=(5, 5))
    plt.scatter(x[:, 0], x[:, 1], color="tab:blue", alpha=0.7, edgecolors="none")
    plt.plot(manifold[:, 0], manifold[:, 1], label='Manifold', color="black")
    plt.plot(x[path, 0], x[path, 1], label='Isomap Geodesic', color="tab:orange")
    plt.plot(varphi_geo[:, 0], varphi_geo[:, 1], label='RAE', color="tab:red")
    plt.scatter(x[start_geo, 0], x[start_geo, 1], color="tab:orange")
    plt.scatter(x[end_geo, 0], x[end_geo, 1], color="tab:orange")
    plt.xlim([-1.5, 1.5])
    plt.ylim([-1., 1.])
    plt.xlabel(r'$\boldsymbol{x}_{1}$')
    plt.ylabel(r'$\boldsymbol{x}_{2}$')
    plt.tight_layout()
    plt.savefig("plots/interpolation_geodesic_example.png", dpi=300, transparent=True)
    plt.close()



def one_NN_classification(data, samples, n_splits=5, shuffle=True, random_state=0):
    """Cross-validated 1-nearest-neighbour accuracy at separating two point sets.

    Rows of ``data`` are labelled 0 and rows of ``samples`` are labelled 1.
    The sets are stacked, a 1-NN classifier is trained and evaluated with
    ``n_splits``-fold cross-validation, and the mean fold accuracy is returned.

    Args:
        data: Reference points, shape ``(n_data, D)``.
        samples: Generated points, shape ``(n_samples, D)``.
        n_splits: Number of cross-validation folds.
        shuffle: Shuffle before splitting into folds. The stacked array is
            ``[all data; all samples]``, so unshuffled folds are contiguous,
            nearly single-class blocks. The held-out class is then
            under-represented in training, which biases 1-NN predictions
            toward the other label.
        random_state: Seed for the shuffle (ignored when ``shuffle`` is False),
            keeping results reproducible.

    Returns:
        Mean accuracy over the folds as a scalar ``jnp`` array.
    """
    data_labels = jnp.zeros(len(data))
    generated_labels = jnp.ones(len(samples))
    combined_labels = jnp.concatenate((data_labels, generated_labels))
    combined_data = jnp.concatenate((data, samples), axis=0)

    # sklearn rejects a random_state when shuffle is False.
    kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state if shuffle else None)
    accuracies = []

    for train_index, test_index in kf.split(combined_data):
        train_data, test_data = combined_data[train_index], combined_data[test_index]
        train_labels, test_labels = combined_labels[train_index], combined_labels[test_index]
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(train_data, train_labels)
        predictions = jnp.array(knn.predict(test_data))
        accuracies.append(jnp.mean(predictions == test_labels))

    return jnp.mean(jnp.array(accuracies))


def _save_scatter_animation(traj, path, label):
    """Render ``traj`` frame by frame as a 2-D scatter and save it as HTML.

    Args:
        traj: Sequence of frames; frame ``i`` is drawn as
            ``traj[i][:, 0]`` against ``traj[i][:, 1]``.
        path: Output file, written with matplotlib's ``html`` writer.
        label: Label attached to the scatter artist (no legend is drawn).
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    def animate(i):
        ax.cla()
        ax.scatter(traj[i][:, 0], traj[i][:, 1], alpha=0.5, color="tab:blue", label=label)
        ax.set_xlim(-2.5, 2.5)
        ax.set_ylim(-1.2, 1.2)
        ax.axis('off')
        fig.patch.set_alpha(0)
        ax.patch.set_alpha(0)

    ani = animation.FuncAnimation(fig, animate, frames=len(traj))
    ani.save(path, writer='html', fps=60, dpi=300)
    # Close so animation figures do not accumulate across calls.
    plt.close(fig)


def generation_experiment():
    """Evaluate CFM, PFM and d-PFM samplers on the ARCH data.

    Loads the pretrained CNF ``varphi`` and three flow-matching vector fields.
    For each entry of ``n_sim_steps_list`` it simulates sample trajectories
    and scores the final samples against the test split with
    `one_NN_classification`. The scores are written as a LaTeX table to
    ``plots/generation_1nn_accuracy.txt``. HTML animations of the
    trajectories from the last step count are also saved.
    """
    cfm_args = options.get_CFM_args()
    pfm_args = options.get_PFM_args()
    d_pfm_args = options.get_d_PFM_args()

    key = random.PRNGKey(cfm_args.seed)
    # Keep the 4-way split so rs_key / init_key are the same keys as before.
    _, rs_key, _, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)
    manifold_fig, ax = plt.subplots(figsize=(5, 5))
    ax.cla()
    plt.scatter(x[:, 0], x[:, 1], alpha=0.5, color="tab:blue")
    plt.xlim(-2.5, 2.5)
    plt.ylim(-1.2, 1.2)
    ax.axis('off')
    manifold_fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    plt.savefig("plots/x.png", dpi=300)
    plt.close()

    if cfm_args.split:
        train_idx = random.choice(rs_key, x.shape[0], (int(x.shape[0] * cfm_args.train_size),), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(x.shape[0]), train_idx)
    else:
        train_idx = jnp.arange(x.shape[0])
        test_idx = train_idx
    n_test = int(test_idx.shape[0])
    x_test = x[test_idx]

    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=pfm_args.varphi_hidden_nf, n_layers=pfm_args.varphi_n_layers, n_steps=pfm_args.varphi_n_steps, seed=pfm_args.seed)
    cfm_vf = model.VectorField(x_dim=x.shape[1], hidden_nf=cfm_args.hidden_nf, n_layers=cfm_args.n_layers)
    pfm_vf = model.VectorField(x_dim=x.shape[1], hidden_nf=pfm_args.hidden_nf, n_layers=pfm_args.n_layers)
    d_pfm_vf = model.VectorField(x_dim=d_pfm_args.n_dim, hidden_nf=d_pfm_args.hidden_nf, n_layers=d_pfm_args.n_layers)

    with open(pfm_args.path, 'rb') as fp:
        varphi_pars = pickle.load(fp)

    with open("models/generation/run-20240728_141230-16wi06qg/files/parameters", 'rb') as fp:
        cfm_pars = pickle.load(fp)

    with open("models/generation/run-20240728_141441-jydz5v0j/files/parameters", 'rb') as fp:
        pfm_pars = pickle.load(fp)

    with open("models/generation/run-20241001_215702-m5c8z97j/files/parameters", 'rb') as fp:
        d_pfm_pars = pickle.load(fp)

    print("Number of parameters CFM: ", sum([p.size for p in jax.tree.leaves(cfm_pars)]))
    print("Number of parameters PFM: ", sum([p.size for p in jax.tree.leaves(pfm_pars)]))
    print("Number of parameters d-PFM: ", sum([p.size for p in jax.tree.leaves(d_pfm_pars)]))

    # Each FMCNF is initialised only to obtain the parameter tree; the vector
    # field parameters are then overwritten with the loaded ones.
    cfm_cnf = model.FMCNF(cfm_vf, n_steps=20, seed=cfm_args.seed, scheme="rk4")
    cfm_cnf_pars = cfm_cnf.init(init_key, x[train_idx[:10]])
    cfm_cnf_pars["params"]["vector_field"] = cfm_pars["params"]

    pfm_cnf = model.FMCNF(pfm_vf, n_steps=20, seed=pfm_args.seed, scheme="rk4")
    pfm_cnf_pars = pfm_cnf.init(init_key, x[train_idx[:10]])
    pfm_cnf_pars["params"]["vector_field"] = pfm_pars["params"]

    # NOTE(review): d_pfm_vf has x_dim=n_dim but is initialised with full
    # D-dimensional inputs; harmless if FMCNF has no params outside
    # "vector_field" — confirm.
    d_pfm_cnf = model.FMCNF(d_pfm_vf, n_steps=20, seed=d_pfm_args.seed, scheme="rk4")
    d_pfm_cnf_pars = d_pfm_cnf.init(init_key, x[train_idx[:10]])
    d_pfm_cnf_pars["params"]["vector_field"] = d_pfm_pars["params"]

    z, _, _, _ = varphi.apply(varphi_pars, x)
    plt.scatter(z[test_idx, 0], z[test_idx, 1], s=1)
    plt.savefig("plots/z_test.png")
    plt.close()
    z_mean, z_std = z.mean(axis=0), z.std(axis=0)
    x_mean, x_std = x.mean(axis=0), x.std(axis=0)
    z_normalized = (z - z_mean) / z_std
    plt.scatter(z_normalized[test_idx, 0], z_normalized[test_idx, 1], s=1)
    plt.savefig("plots/z_normalized_test.png")
    plt.close()

    # Simulate at least as many samples as there are test points so the 1-NN
    # test always compares two equally sized sets.
    n_samples = max(500, n_test)
    n_sim_steps_list = [2, 4, 6, 8, 16]
    cfm_1nn_accuracies = []
    pfm_1nn_accuracies = []
    d_pfm_1nn_accuracies = []

    for n_sim_steps in n_sim_steps_list:

        # Samples are generated in normalised space and mapped back with the
        # data (CFM) or latent (PFM, d-PFM) mean/std.
        cfm_traj = cfm_cnf.apply(cfm_cnf_pars, n_samples, n_sim_steps, method="simulate") * x_std + x_mean
        z_pfm_traj = pfm_cnf.apply(pfm_cnf_pars, n_samples, n_sim_steps, method="simulate").reshape(-1, z.shape[1]) * z_std + z_mean
        pfm_traj = varphi.apply(varphi_pars, z_pfm_traj, method="inverse")[0].reshape(-1, n_samples, x.shape[1])
        ld_z_d_pfm_traj = d_pfm_cnf.apply(d_pfm_cnf_pars, n_samples, n_sim_steps, method="simulate").reshape(-1, d_pfm_args.n_dim)
        # d-PFM only models the last n_dim latent coordinates; the others are zero-padded.
        z_d_pfm_traj = jnp.concatenate([jnp.zeros((ld_z_d_pfm_traj.shape[0], x.shape[1] - d_pfm_args.n_dim)), ld_z_d_pfm_traj], axis=1) * z_std + z_mean
        d_pfm_traj = varphi.apply(varphi_pars, z_d_pfm_traj, method="inverse")[0].reshape(-1, n_samples, x.shape[1])

        # Generated samples have no correspondence with dataset indices, so
        # take the first n_test final samples. Indexing them with test_idx is
        # wrong: JAX clamps out-of-range gather indices, which silently
        # duplicated the last sample.
        cfm_1nn_accuracies.append(one_NN_classification(x_test, cfm_traj[-1][:n_test]))
        pfm_1nn_accuracies.append(one_NN_classification(x_test, pfm_traj[-1][:n_test]))
        d_pfm_1nn_accuracies.append(one_NN_classification(x_test, d_pfm_traj[-1][:n_test]))

    table = [["Model"] + [str(n) + " Steps" for n in n_sim_steps_list],
             ["CFM", *cfm_1nn_accuracies],
             ["PFM", *pfm_1nn_accuracies],
             ["d-PFM", *d_pfm_1nn_accuracies]]

    floatfmt = (None,) + (".3",) * len(n_sim_steps_list)
    formatted_table = tabulate.tabulate(table, headers="firstrow", tablefmt="latex_raw", floatfmt=floatfmt)

    with open("plots/generation_1nn_accuracy.txt", "w") as f:
        f.write(formatted_table)

    plt.ioff()

    # Animations use the trajectories from the last (largest) step count.
    _save_scatter_animation(cfm_traj, 'plots/generation_cfm_traj.html', 'CFM')
    _save_scatter_animation(pfm_traj, 'plots/generation_pfm_traj.html', 'PFM')
    _save_scatter_animation(z_pfm_traj.reshape(-1, n_samples, x.shape[1]), 'plots/generation_z_pfm_traj.html', 'PFM (latent)')
    _save_scatter_animation(d_pfm_traj, 'plots/generation_d_pfm_traj.html', 'd-PFM')

# Run all experiments only when executed as a script, not on import.
if __name__ == "__main__":
    print("Start Ablation Experiments for ARCH Dataset")
    ablation_experiment()
    print("Start Interpolation Experiments for ARCH Dataset")
    interpolation_experiment()
    print("Start Generation Experiments for ARCH Dataset")
    generation_experiment()